{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 10. Dynamic Memory Networks for Question Answering"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I recommend you take a look at these materials first."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture16-DMN-QA.pdf\n",
    "* https://arxiv.org/abs/1506.07285"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import nltk\n",
    "import random\n",
    "import numpy as np\n",
    "from collections import Counter, OrderedDict\n",
    "import nltk\n",
    "from copy import deepcopy\n",
    "import os\n",
    "import re\n",
    "import unicodedata\n",
    "flatten = lambda l: [item for sublist in l for item in sublist]\n",
    "\n",
    "from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence\n",
    "random.seed(1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "USE_CUDA = torch.cuda.is_available()\n",
    "gpus = [0]\n",
    "torch.cuda.set_device(gpus[0])\n",
    "\n",
    "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
    "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
    "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def getBatch(batch_size, train_data):\n",
    "    random.shuffle(train_data)\n",
    "    sindex = 0\n",
    "    eindex = batch_size\n",
    "    while eindex < len(train_data):\n",
    "        batch = train_data[sindex: eindex]\n",
    "        temp = eindex\n",
    "        eindex = eindex + batch_size\n",
    "        sindex = temp\n",
    "        yield batch\n",
    "    \n",
    "    if eindex >= len(train_data):\n",
    "        batch = train_data[sindex:]\n",
    "        yield batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def pad_to_batch(batch, w_to_ix): # for bAbI dataset\n",
    "    \"\"\"Pad a mini-batch of (facts, question, answer) examples to common sizes.\n",
    "\n",
    "    Args:\n",
    "        batch: list of examples, each being (list of (1, len) index rows for the\n",
    "            facts, (1, len) question row, (1, len) answer row).\n",
    "        w_to_ix: vocabulary dict; must contain '<PAD>'.\n",
    "\n",
    "    Returns:\n",
    "        facts: list with one (max_fact, max_len) padded tensor per example.\n",
    "        fact_masks: list of ByteTensor masks of the same shapes, 1 where the token is '<PAD>'.\n",
    "        questions: (B, max_q) padded questions.\n",
    "        question_masks: (B, max_q) ByteTensor, 1 where the token is '<PAD>'.\n",
    "        answers: (B, max_a) padded answers.\n",
    "    \"\"\"\n",
    "    pad_ix = w_to_ix['<PAD>']\n",
    "\n",
    "    def pad_row(row, width):\n",
    "        # Right-pad a (1, len) row with <PAD> up to `width` columns.\n",
    "        missing = width - row.size(1)\n",
    "        if missing <= 0:\n",
    "            return row\n",
    "        return torch.cat([row, Variable(LongTensor([pad_ix] * missing)).view(1, -1)], 1)\n",
    "\n",
    "    def pad_mask(rows):\n",
    "        # Same shape as `rows`; compares against the real <PAD> index instead of a literal 0.\n",
    "        return Variable(ByteTensor([[int(ix == pad_ix) for ix in row] for row in rows.data.tolist()]))\n",
    "\n",
    "    fact, q, a = list(zip(*batch))\n",
    "    max_fact = max([len(f) for f in fact])\n",
    "    max_len = max([f.size(1) for f in flatten(fact)])\n",
    "    max_q = max([qq.size(1) for qq in q])\n",
    "    max_a = max([aa.size(1) for aa in a])\n",
    "\n",
    "    facts, fact_masks = [], []\n",
    "    for story in fact:\n",
    "        rows = [pad_row(f, max_len) for f in story]\n",
    "        # Stories with fewer facts get all-<PAD> rows so every example has max_fact rows.\n",
    "        for _ in range(max_fact - len(story)):\n",
    "            rows.append(Variable(LongTensor([pad_ix] * max_len)).view(1, -1))\n",
    "        fact_p_t = torch.cat(rows)\n",
    "        facts.append(fact_p_t)\n",
    "        fact_masks.append(pad_mask(fact_p_t))\n",
    "\n",
    "    questions = torch.cat([pad_row(qq, max_q) for qq in q])\n",
    "    answers = torch.cat([pad_row(aa, max_a) for aa in a])\n",
    "    question_masks = pad_mask(questions)\n",
    "\n",
    "    return facts, fact_masks, questions, question_masks, answers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def prepare_sequence(seq, to_index):\n",
    "    \"\"\"Convert a token sequence into a LongTensor Variable of vocabulary indices.\n",
    "\n",
    "    Tokens without a (non-None) entry in `to_index` are mapped to the '<UNK>' index.\n",
    "    \"\"\"\n",
    "    idxs = []\n",
    "    for w in seq:\n",
    "        ix = to_index.get(w)\n",
    "        idxs.append(to_index[\"<UNK>\"] if ix is None else ix)\n",
    "    return Variable(LongTensor(idxs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data load and Preprocessing "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### bAbI dataset (https://research.fb.com/downloads/babi/)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def bAbI_data_load(path):\n",
    "    try:\n",
    "        data = open(path).readlines()\n",
    "    except:\n",
    "        print(\"Such a file does not exist at %s\".format(path))\n",
    "        return None\n",
    "    \n",
    "    data = [d[:-1] for d in data]\n",
    "    data_p = []\n",
    "    fact = []\n",
    "    qa = []\n",
    "    try:\n",
    "        for d in data:\n",
    "            index = d.split(' ')[0]\n",
    "            if index == '1':\n",
    "                fact = []\n",
    "                qa = []\n",
    "            if '?' in d:\n",
    "                temp = d.split('\\t')\n",
    "                q = temp[0].strip().replace('?', '').split(' ')[1:] + ['?']\n",
    "                a = temp[1].split() + ['</s>']\n",
    "                stemp = deepcopy(fact)\n",
    "                data_p.append([stemp, q, a])\n",
    "            else:\n",
    "                tokens = d.replace('.', '').split(' ')[1:] + ['</s>']\n",
    "                fact.append(tokens)\n",
    "    except:\n",
    "        print(\"Please check the data is right\")\n",
    "        return None\n",
    "    return data_p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_data = bAbI_data_load('../dataset/bAbI/en-10k/qa5_three-arg-relations_train.txt')\n",
    "# bAbI_data_load signals failure by returning None; stop here rather than fail later with an unrelated error.\n",
    "assert train_data, 'no training examples were loaded'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[['Bill', 'travelled', 'to', 'the', 'office', '</s>'],\n",
       "  ['Bill', 'picked', 'up', 'the', 'football', 'there', '</s>'],\n",
       "  ['Bill', 'went', 'to', 'the', 'bedroom', '</s>'],\n",
       "  ['Bill', 'gave', 'the', 'football', 'to', 'Fred', '</s>']],\n",
       " ['What', 'did', 'Bill', 'give', 'to', 'Fred', '?'],\n",
       " ['football', '</s>']]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Split the examples column-wise: all fact lists, all questions, all answers.\n",
    "fact, q, a = zip(*train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sort the vocabulary: the iteration order of a set of strings changes between kernel restarts,\n",
    "# which would give the same word a different index on every run.\n",
    "vocab = sorted(set(flatten(flatten(fact)) + flatten(q) + flatten(a)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# The first four indices are reserved for the special tokens; corpus words follow in vocab order.\n",
    "word2index = {'<PAD>': 0, '<UNK>': 1, '<s>': 2, '</s>': 3}\n",
    "for vo in vocab:\n",
    "    word2index.setdefault(vo, len(word2index))\n",
    "index2word = dict((ix, word) for word, ix in word2index.items())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "44"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(word2index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "for t in train_data:\n",
    "    # The conversion overwrites the token lists in place. Skip examples that were already\n",
    "    # converted so that re-running this cell does not feed tensors back into prepare_sequence.\n",
    "    if not isinstance(t[1], list):\n",
    "        continue\n",
    "    for i,fact in enumerate(t[0]):\n",
    "        t[0][i] = prepare_sequence(fact, word2index).view(1, -1)\n",
    "    \n",
    "    t[1] = prepare_sequence(t[1], word2index).view(1, -1)\n",
    "    t[2] = prepare_sequence(t[2], word2index).view(1, -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modeling "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../images/10.dmn-architecture.png\">\n",
    "<center>borrowed image from https://arxiv.org/pdf/1506.07285.pdf</center>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class DMN(nn.Module):\n",
    "    \"\"\"Dynamic Memory Network with input, question, episodic memory and answer modules.\n",
    "\n",
    "    Args:\n",
    "        input_size: number of rows of the embedding table.\n",
    "        hidden_size: width of the embeddings and of every recurrent state.\n",
    "        output_size: number of classes scored by the answer module.\n",
    "        dropout_p: dropout probability used on the embeddings when `is_training` is set.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_size, hidden_size, output_size, dropout_p=0.1):\n",
    "        super(DMN, self).__init__()\n",
    "        \n",
    "        self.hidden_size = hidden_size\n",
    "        self.embed = nn.Embedding(input_size, hidden_size, padding_idx=0) #sparse=True)\n",
    "        self.input_gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
    "        self.question_gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
    "        \n",
    "        # Scores one fact against the question and the current memory (4 * hidden_size features -> 1 gate value).\n",
    "        self.gate = nn.Sequential(\n",
    "                            nn.Linear(hidden_size * 4, hidden_size),\n",
    "                            nn.Tanh(),\n",
    "                            nn.Linear(hidden_size, 1),\n",
    "                            nn.Sigmoid()\n",
    "                        )\n",
    "        \n",
    "        self.attention_grucell =  nn.GRUCell(hidden_size, hidden_size)\n",
    "        self.memory_grucell = nn.GRUCell(hidden_size, hidden_size)\n",
    "        self.answer_grucell = nn.GRUCell(hidden_size * 2, hidden_size)\n",
    "        self.answer_fc = nn.Linear(hidden_size, output_size)\n",
    "        \n",
    "        self.dropout = nn.Dropout(dropout_p)\n",
    "        \n",
    "    def init_hidden(self, inputs):\n",
    "        \"\"\"Return a zero initial state of shape (1, inputs.size(0), hidden_size) on the active device.\"\"\"\n",
    "        hidden = Variable(torch.zeros(1, inputs.size(0), self.hidden_size))\n",
    "        return hidden.cuda() if USE_CUDA else hidden\n",
    "    \n",
    "    def init_weight(self):\n",
    "        \"\"\"Xavier-initialise every weight matrix and zero the bias of the output layer.\"\"\"\n",
    "        nn.init.xavier_uniform(self.embed.state_dict()['weight'])\n",
    "        # The init above also overwrites the zero vector nn.Embedding keeps for padding_idx; restore it.\n",
    "        self.embed.weight.data[self.embed.padding_idx].fill_(0)\n",
    "        \n",
    "        recurrent_and_gate = (self.input_gru, self.question_gru, self.gate,\n",
    "                              self.attention_grucell, self.memory_grucell, self.answer_grucell)\n",
    "        for module in recurrent_and_gate:\n",
    "            for name, param in module.state_dict().items():\n",
    "                if 'weight' in name: nn.init.xavier_normal(param)\n",
    "        \n",
    "        nn.init.xavier_normal(self.answer_fc.state_dict()['weight'])\n",
    "        self.answer_fc.bias.data.fill_(0)\n",
    "        \n",
    "    @staticmethod\n",
    "    def _last_unpadded_outputs(outputs, masks):\n",
    "        \"\"\"For each sequence pick the GRU output at its last unmasked step; returns (N, D).\"\"\"\n",
    "        picked = []\n",
    "        for i, o in enumerate(outputs): # N,T,D\n",
    "            real_length = masks[i].data.tolist().count(0)\n",
    "            # A fully masked row has real_length == 0, so index -1 selects its final step.\n",
    "            picked.append(o[real_length - 1])\n",
    "        return torch.cat(picked).view(outputs.size(0), -1)\n",
    "        \n",
    "    def forward(self, facts, fact_masks, questions, question_masks, num_decode, episodes=3, is_training=False):\n",
    "        \"\"\"\n",
    "        facts : (B,T_C,T_I) / LongTensor in List # batch_size, num_of_facts, length_of_each_fact(padded)\n",
    "        fact_masks : (B,T_C,T_I) / ByteTensor in List # batch_size, num_of_facts, length_of_each_fact(padded)\n",
    "        questions : (B,T_Q) / LongTensor # batch_size, question_length\n",
    "        question_masks : (B,T_Q) / ByteTensor # batch_size, question_length\n",
    "        num_decode : number of answer steps to decode\n",
    "        episodes : number of passes of the episodic memory over the facts\n",
    "        is_training : apply dropout to the embeddings when True\n",
    "        returns : (B * num_decode, output_size) log-probabilities\n",
    "        \"\"\"\n",
    "        # Input Module\n",
    "        C = [] # encoded facts\n",
    "        for fact, fact_mask in zip(facts, fact_masks):\n",
    "            embeds = self.embed(fact)\n",
    "            if is_training:\n",
    "                embeds = self.dropout(embeds)\n",
    "            hidden = self.init_hidden(fact)\n",
    "            outputs, hidden = self.input_gru(embeds, hidden)\n",
    "            C.append(self._last_unpadded_outputs(outputs, fact_mask).unsqueeze(0))\n",
    "        \n",
    "        encoded_facts = torch.cat(C) # B,T_C,D\n",
    "        \n",
    "        # Question Module\n",
    "        embeds = self.embed(questions)\n",
    "        if is_training:\n",
    "            embeds = self.dropout(embeds)\n",
    "        hidden = self.init_hidden(questions)\n",
    "        outputs, hidden = self.question_gru(embeds, hidden)\n",
    "        \n",
    "        if isinstance(question_masks, torch.autograd.variable.Variable):\n",
    "            encoded_question = self._last_unpadded_outputs(outputs, question_masks) # B,D\n",
    "        else: # for inference mode\n",
    "            encoded_question = hidden.squeeze(0) # B,D\n",
    "            \n",
    "        # Episodic Memory Module\n",
    "        memory = encoded_question\n",
    "        T_C = encoded_facts.size(1)\n",
    "        B = encoded_facts.size(0)\n",
    "        facts_by_step = encoded_facts.transpose(0, 1) # T_C,B,D ; transposed once instead of on every use\n",
    "        for i in range(episodes):\n",
    "            hidden = self.init_hidden(facts_by_step[0]).squeeze(0) # B,D\n",
    "            for t in range(T_C):\n",
    "                #TODO: fact masking\n",
    "                #TODO: gate function => softmax\n",
    "                fact_t = facts_by_step[t] # B,D\n",
    "                z = torch.cat([\n",
    "                                    fact_t * encoded_question, # B,D , element-wise product\n",
    "                                    fact_t * memory, # B,D , element-wise product\n",
    "                                    torch.abs(fact_t - encoded_question), # B,D\n",
    "                                    torch.abs(fact_t - memory) # B,D\n",
    "                                ], 1)\n",
    "                g_t = self.gate(z) # B,1 scalar\n",
    "                # Gated update: take the new GRU state in proportion to g_t, keep the old one otherwise.\n",
    "                hidden = g_t * self.attention_grucell(fact_t, hidden) + (1 - g_t) * hidden\n",
    "                \n",
    "            e = hidden\n",
    "            memory = self.memory_grucell(e, memory)\n",
    "        \n",
    "        # Answer Module\n",
    "        answer_hidden = memory\n",
    "        # The start token index is read from the notebook-level word2index.\n",
    "        start_decode = Variable(LongTensor([[word2index['<s>']] * memory.size(0)])).transpose(0, 1)\n",
    "        y_t_1 = self.embed(start_decode).squeeze(1) # B,D\n",
    "        \n",
    "        decodes = []\n",
    "        for t in range(num_decode):\n",
    "            # NOTE(review): y_t_1 is never updated in this loop, so every step is fed the <s> embedding -- confirm this is intended.\n",
    "            answer_hidden = self.answer_grucell(torch.cat([y_t_1, encoded_question], 1), answer_hidden)\n",
    "            decodes.append(F.log_softmax(self.answer_fc(answer_hidden),1))\n",
    "        return torch.cat(decodes, 1).view(B * num_decode, -1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training takes a while if you use only the CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "HIDDEN_SIZE = 80\n",
    "BATCH_SIZE = 64\n",
    "LR = 0.001\n",
    "EPOCH = 50\n",
    "NUM_EPISODE = 3\n",
    "EARLY_STOPPING = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model = DMN(len(word2index), HIDDEN_SIZE, len(word2index))\n",
    "model.init_weight()\n",
    "if USE_CUDA:\n",
    "    model = model.cuda()\n",
    "\n",
    "# DMN.forward already returns log-probabilities (F.log_softmax), so NLLLoss is the matching criterion;\n",
    "# CrossEntropyLoss would apply log_softmax a second time.\n",
    "loss_function = nn.NLLLoss(ignore_index=word2index['<PAD>'])\n",
    "optimizer = optim.Adam(model.parameters(), lr=LR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0/50] mean_loss : 3.86\n",
      "[0/50] mean_loss : 1.32\n",
      "[1/50] mean_loss : 0.68\n",
      "[1/50] mean_loss : 0.65\n",
      "[2/50] mean_loss : 0.62\n",
      "[2/50] mean_loss : 0.65\n",
      "[3/50] mean_loss : 0.65\n",
      "[3/50] mean_loss : 0.64\n",
      "[4/50] mean_loss : 0.60\n",
      "[4/50] mean_loss : 0.62\n",
      "[5/50] mean_loss : 0.63\n",
      "[5/50] mean_loss : 0.61\n",
      "[6/50] mean_loss : 0.60\n",
      "[6/50] mean_loss : 0.61\n",
      "[7/50] mean_loss : 0.63\n",
      "[7/50] mean_loss : 0.60\n",
      "[8/50] mean_loss : 0.62\n",
      "[8/50] mean_loss : 0.60\n",
      "[9/50] mean_loss : 0.58\n",
      "[9/50] mean_loss : 0.60\n",
      "[10/50] mean_loss : 0.60\n",
      "[10/50] mean_loss : 0.60\n",
      "[11/50] mean_loss : 0.62\n",
      "[11/50] mean_loss : 0.60\n",
      "[12/50] mean_loss : 0.61\n",
      "[12/50] mean_loss : 0.60\n",
      "[13/50] mean_loss : 0.57\n",
      "[13/50] mean_loss : 0.60\n",
      "[14/50] mean_loss : 0.59\n",
      "[14/50] mean_loss : 0.60\n",
      "[15/50] mean_loss : 0.61\n",
      "[15/50] mean_loss : 0.60\n",
      "[16/50] mean_loss : 0.59\n",
      "[16/50] mean_loss : 0.60\n",
      "[17/50] mean_loss : 0.59\n",
      "[17/50] mean_loss : 0.60\n",
      "[18/50] mean_loss : 0.51\n",
      "[18/50] mean_loss : 0.50\n",
      "[19/50] mean_loss : 0.44\n",
      "[19/50] mean_loss : 0.37\n",
      "[20/50] mean_loss : 0.30\n",
      "[20/50] mean_loss : 0.33\n",
      "[21/50] mean_loss : 0.31\n",
      "[21/50] mean_loss : 0.31\n",
      "[22/50] mean_loss : 0.29\n",
      "[22/50] mean_loss : 0.31\n",
      "[23/50] mean_loss : 0.29\n",
      "[23/50] mean_loss : 0.31\n",
      "[24/50] mean_loss : 0.24\n",
      "[24/50] mean_loss : 0.31\n",
      "[25/50] mean_loss : 0.30\n",
      "[25/50] mean_loss : 0.30\n",
      "[26/50] mean_loss : 0.14\n",
      "[26/50] mean_loss : 0.16\n",
      "[27/50] mean_loss : 0.12\n",
      "[27/50] mean_loss : 0.15\n",
      "[28/50] mean_loss : 0.18\n",
      "[28/50] mean_loss : 0.14\n",
      "[29/50] mean_loss : 0.12\n",
      "[29/50] mean_loss : 0.14\n",
      "[30/50] mean_loss : 0.14\n",
      "[30/50] mean_loss : 0.14\n",
      "[31/50] mean_loss : 0.13\n",
      "[31/50] mean_loss : 0.14\n",
      "[32/50] mean_loss : 0.11\n",
      "[32/50] mean_loss : 0.13\n",
      "[33/50] mean_loss : 0.08\n",
      "[33/50] mean_loss : 0.06\n",
      "[34/50] mean_loss : 0.01\n",
      "[34/50] mean_loss : 0.03\n",
      "[35/50] mean_loss : 0.01\n",
      "Early Stopping!\n"
     ]
    }
   ],
   "source": [
    "# Reset the flag so that re-running this cell trains again after an earlier early stop.\n",
    "EARLY_STOPPING = False\n",
    "for epoch in range(EPOCH):\n",
    "    losses = []\n",
    "    if EARLY_STOPPING: \n",
    "        break\n",
    "        \n",
    "    for i,batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n",
    "        facts, fact_masks, questions, question_masks, answers = pad_to_batch(batch, word2index)\n",
    "        \n",
    "        model.zero_grad()\n",
    "        pred = model(facts, fact_masks, questions, question_masks, answers.size(1), NUM_EPISODE, True)\n",
    "        loss = loss_function(pred, answers.view(-1))\n",
    "        # view(-1) makes the read work whether the loss is a 1-element or a 0-dim tensor.\n",
    "        losses.append(float(loss.data.view(-1)[0]))\n",
    "        \n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        if i % 100 == 0:\n",
    "            # Mean over the batches seen since the previous report.\n",
    "            mean_loss = np.mean(losses)\n",
    "            print(\"[%d/%d] mean_loss : %0.2f\" %(epoch, EPOCH, mean_loss))\n",
    "            \n",
    "            if mean_loss < 0.01:\n",
    "                EARLY_STOPPING = True\n",
    "                print(\"Early Stopping!\")\n",
    "                break\n",
    "            losses = []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def pad_to_fact(fact, x_to_ix): # this is for inference\n",
    "    \"\"\"Pad the facts of one story to a common length and build the padding mask.\n",
    "\n",
    "    Args:\n",
    "        fact: list of (1, len) index rows, one per fact.\n",
    "        x_to_ix: vocabulary dict; must contain '<PAD>'.\n",
    "\n",
    "    Returns:\n",
    "        fact: (num_facts, max_len) padded facts.\n",
    "        fact_mask: ByteTensor of the same shape, 1 where the token is '<PAD>'.\n",
    "    \"\"\"\n",
    "    pad_ix = x_to_ix['<PAD>']\n",
    "    max_x = max([s.size(1) for s in fact])\n",
    "    x_p = []\n",
    "    for row in fact:\n",
    "        missing = max_x - row.size(1)\n",
    "        if missing > 0:\n",
    "            row = torch.cat([row, Variable(LongTensor([pad_ix] * missing)).view(1, -1)], 1)\n",
    "        x_p.append(row)\n",
    "        \n",
    "    fact = torch.cat(x_p)\n",
    "    # Compare against the real <PAD> index rather than a literal 0.\n",
    "    fact_mask = Variable(ByteTensor([[int(ix == pad_ix) for ix in row] for row in fact.data.tolist()]))\n",
    "    return fact, fact_mask"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare Test data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "test_data = bAbI_data_load('../dataset/bAbI/en-10k/qa5_three-arg-relations_test.txt')\n",
    "# bAbI_data_load signals failure by returning None; stop here rather than fail later with an unrelated error.\n",
    "assert test_data, 'no test examples were loaded'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "for t in test_data:\n",
    "    # The conversion overwrites the token lists in place. Skip examples that were already\n",
    "    # converted so that re-running this cell does not feed tensors back into prepare_sequence.\n",
    "    if not isinstance(t[1], list):\n",
    "        continue\n",
    "    for i, fact in enumerate(t[0]):\n",
    "        t[0][i] = prepare_sequence(fact, word2index).view(1, -1)\n",
    "    \n",
    "    t[1] = prepare_sequence(t[1], word2index).view(1, -1)\n",
    "    t[2] = prepare_sequence(t[2], word2index).view(1, -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Accuracy "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "accuracy = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "97.39999999999999\n"
     ]
    }
   ],
   "source": [
    "# Start from zero on every run so that re-executing this cell does not add to an old count.\n",
    "accuracy = 0\n",
    "for t in test_data:\n",
    "    fact, fact_mask = pad_to_fact(t[0], word2index)\n",
    "    question = t[1]\n",
    "    # A single unpadded question: nothing is masked.\n",
    "    question_mask = Variable(ByteTensor([0] * t[1].size(1)), volatile=False).unsqueeze(0)\n",
    "    answer = t[2].squeeze(0)\n",
    "    \n",
    "    model.zero_grad()\n",
    "    pred = model([fact], [fact_mask], question, question_mask, answer.size(0), NUM_EPISODE)\n",
    "    # An example counts as correct only when every decoded token matches the answer.\n",
    "    if pred.max(1)[1].data.tolist() == answer.data.tolist():\n",
    "        accuracy += 1\n",
    "\n",
    "print(accuracy/len(test_data) * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sample test result "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Facts : \n",
      "Bill went back to the bedroom </s>\n",
      "Mary went to the office </s> <PAD>\n",
      "Jeff journeyed to the kitchen </s> <PAD>\n",
      "Fred journeyed to the kitchen </s> <PAD>\n",
      "Fred got the milk there </s> <PAD>\n",
      "Fred handed the milk to Jeff </s>\n",
      "Jeff passed the milk to Fred </s>\n",
      "Fred gave the milk to Jeff </s>\n",
      "\n",
      "Question :  Who received the milk ?\n",
      "\n",
      "Answer :  Jeff </s>\n",
      "Prediction :  Jeff </s>\n"
     ]
    }
   ],
   "source": [
    "t = random.choice(test_data)\n",
    "fact, fact_mask = pad_to_fact(t[0], word2index)\n",
    "question = t[1]\n",
    "question_mask = Variable(ByteTensor([0] * t[1].size(1)), volatile=False).unsqueeze(0)\n",
    "answer = t[2].squeeze(0)\n",
    "\n",
    "model.zero_grad()\n",
    "pred = model([fact], [fact_mask], question, question_mask, answer.size(0), NUM_EPISODE)\n",
    "\n",
    "# Map every index sequence back to words and print each part of the example.\n",
    "print(\"Facts : \")\n",
    "print('\\n'.join(' '.join(index2word[x] for x in f) for f in fact.data.tolist()))\n",
    "print(\"\")\n",
    "print(\"Question : \", ' '.join(index2word[x] for x in question.data.tolist()[0]))\n",
    "print(\"\")\n",
    "print(\"Answer : \", ' '.join(index2word[x] for x in answer.data.tolist()))\n",
    "print(\"Prediction : \", ' '.join(index2word[x] for x in pred.max(1)[1].data.tolist()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Further topics "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* <a href=\"https://arxiv.org/pdf/1603.01417.pdf\">Dynamic Memory Networks for Visual and Textual Question Answering(DMN+)</a>\n",
    "* <a href=\"https://github.com/dandelin/Dynamic-memory-networks-plus-Pytorch\">DMN+ Pytorch implementation</a>\n",
    "* <a href=\"https://arxiv.org/pdf/1611.01604\">Dynamic Coattention Networks For Question Answering</a>\n",
    "* <a href=\"https://arxiv.org/pdf/1711.00106\">DCN+: Mixed Objective and Deep Residual Coattention for Question Answering</a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
